{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Bayesian Dark Knowledge in MXNet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we will show how to implement Bayesian Dark Knowledge <a name=\"ref-1\"/>[(Korattikara, Rathod, Murphy and Welling, 2015)](#cite-korattikara2015bayesian) in MXNet.\n",
    "\n",
    "In applications such as recommendation and control, a Bayesian treatment of neural networks can help because it lets us \n",
    "model the uncertainty of our predictions and avoid overconfident actions <a name=\"ref-2\"/>[(Yeung, Wang, Wang and Shi, 2015)](#cite-bdl). However, Bayesian parameter estimation is non-trivial and much harder than simple point estimation because of the high dimensionality and non-linearity of neural networks. One way to tackle the problem is the expectation propagation approach in <a name=\"ref-3\"/>[(Hern&aacute;ndez-Lobato and Adams, 2015)](#cite-hernandez2015probabilistic), which relies on a predefined parametric form of the posterior distribution. Bayesian Dark Knowledge (BDK), implemented in this notebook, is another solution: it uses Stochastic Gradient Langevin Dynamics (SGLD) to draw samples from the posterior of the Bayesian neural network and fits a student network using these teaching samples. BDK can achieve performance similar to that of the SGLD teacher while being much faster at inference."
   ]
  },
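  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the SGLD step concrete, here is a minimal NumPy sketch. It is an illustration under stated assumptions, not part of the MXNet implementation in this notebook. It samples the mean of a one-dimensional Gaussian with unit variance under a zero-mean Gaussian prior. The toy data, the step size `eps` and the helper name `sgld_step` are hypothetical choices made for this example. Each step moves along a minibatch estimate of the gradient of the log posterior, with the likelihood term rescaled by the ratio of data set size to minibatch size, and then adds Gaussian noise whose variance equals the step size. The last two lines compare the sample average with the closed-form posterior mean of this toy model; the two values should be close if the sampler behaves as intended."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from __future__ import division, print_function\n",
    "import numpy\n",
    "\n",
    "\n",
    "def sgld_step(theta, x_batch, n_total, eps, prior_precision, rng):\n",
    "    # Hypothetical helper: one SGLD step for the mean of a unit-variance Gaussian.\n",
    "    # Gradient of the log prior N(0, 1 / prior_precision)\n",
    "    grad_log_prior = -prior_precision * theta\n",
    "    # Minibatch gradient of the log likelihood, rescaled to the full data set\n",
    "    grad_log_lik = (n_total / x_batch.shape[0]) * numpy.sum(x_batch - theta)\n",
    "    # Injected Gaussian noise with variance eps\n",
    "    noise = rng.normal(0.0, numpy.sqrt(eps))\n",
    "    return theta + 0.5 * eps * (grad_log_prior + grad_log_lik) + noise\n",
    "\n",
    "\n",
    "rng = numpy.random.RandomState(0)\n",
    "data = rng.normal(2.0, 1.0, size=1000)  # toy data, assumed for illustration\n",
    "theta, samples = 0.0, []\n",
    "for t in range(5000):\n",
    "    batch = data[rng.randint(data.shape[0], size=100)]\n",
    "    theta = sgld_step(theta, batch, data.shape[0], eps=1e-4,\n",
    "                      prior_precision=1.0, rng=rng)\n",
    "    if t >= 1000:  # discard the first iterations as burn-in\n",
    "        samples.append(theta)\n",
    "print('SGLD posterior mean estimate: %f' % numpy.mean(samples))\n",
    "print('closed-form posterior mean:   %f' % (data.sum() / (data.shape[0] + 1.0)))\n"
   ]
  },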
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import mxnet as mx\n",
    "import mxnet.ndarray as nd\n",
    "import numpy\n",
    "import time\n",
    "import ssl\n",
    "import os\n",
    "\n",
    "\n",
    "def load_mnist(training_num=50000):\n",
    "    data_path = os.path.join(os.path.dirname(os.path.realpath('__file__')), 'mnist.npz')\n",
    "    if not os.path.isfile(data_path):\n",
    "        from six.moves import urllib\n",
    "        origin = (\n",
    "            'https://github.com/sxjscience/mxnet/raw/master/example/bayesian-methods/mnist.npz'\n",
    "        )\n",
    "        print('Downloading data from %s to %s' % (origin, data_path))\n",
    "        urllib.request.urlretrieve(origin, data_path)\n",
    "        print('Done!')\n",
    "    dat = numpy.load(data_path)\n",
    "    X = (dat['X'][:training_num] / 126.0).astype('float32')\n",
    "    Y = dat['Y'][:training_num]\n",
    "    X_test = (dat['X_test'] / 126.0).astype('float32')\n",
    "    Y_test = dat['Y_test']\n",
    "    Y = Y.reshape((Y.shape[0],))\n",
    "    Y_test = Y_test.reshape((Y_test.shape[0],))\n",
    "    return X, Y, X_test, Y_test\n",
    "\n",
    "\n",
    "def sample_test_acc(exe, X, Y, label_num=None, minibatch_size=100):\n",
    "    pred = numpy.zeros((X.shape[0], label_num)).astype('float32')\n",
    "    data_iter = mx.io.NDArrayIter(data=X, label=Y, batch_size=minibatch_size, shuffle=False)\n",
    "    curr_instance = 0\n",
    "    data_iter.reset()\n",
    "    for batch in data_iter:\n",
    "        exe.arg_dict['data'][:] = batch.data[0]\n",
    "        exe.forward(is_train=False)\n",
    "        batch_size = minibatch_size - batch.pad\n",
    "        pred[curr_instance:curr_instance + batch_size, :] += exe.outputs[0].asnumpy()[:batch_size]\n",
    "        curr_instance += batch_size\n",
    "    correct = (pred.argmax(axis=1) == Y).sum()\n",
    "    total = Y.shape[0]\n",
    "    acc = correct/float(total)\n",
    "    return correct, total, acc\n",
    "\n",
    "\n",
    "def get_executor(sym, ctx, data_inputs, initializer=None):\n",
    "    data_shapes = {k: v.shape for k, v in data_inputs.items()}\n",
    "    arg_names = sym.list_arguments()\n",
    "    aux_names = sym.list_auxiliary_states()\n",
    "    param_names = list(set(arg_names) - set(data_inputs.keys()))\n",
    "    arg_shapes, output_shapes, aux_shapes = sym.infer_shape(**data_shapes)\n",
    "    arg_name_shape = {k: s for k, s in zip(arg_names, arg_shapes)}\n",
    "    params = {n: nd.empty(arg_name_shape[n], ctx=ctx) for n in param_names}\n",
    "    params_grad = {n: nd.empty(arg_name_shape[n], ctx=ctx) for n in param_names}\n",
    "    aux_states = {k: nd.empty(s, ctx=ctx) for k, s in zip(aux_names, aux_shapes)}\n",
    "    exe = sym.bind(ctx=ctx, args=dict(params, **data_inputs),\n",
    "                   args_grad=params_grad,\n",
    "                   aux_states=aux_states)\n",
    "    if initializer is not None:\n",
    "        for k, v in params.items():\n",
    "            initializer(k, v)\n",
    "    return exe, params, params_grad, aux_states\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the helper functions defined, we can implement the core of the method: a training loop that alternates between updating the SGLD teacher and distilling it into the student."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def DistilledSGLD(teacher_sym, student_sym,\n",
    "                  teacher_data_inputs, student_data_inputs,\n",
    "                  X, Y, X_test, Y_test, total_iter_num,\n",
    "                  teacher_learning_rate, student_learning_rate,\n",
    "                  teacher_lr_scheduler=None, student_lr_scheduler=None,\n",
    "                  student_optimizing_algorithm='adam',\n",
    "                  teacher_prior_precision=1, student_prior_precision=0.001,\n",
    "                  perturb_deviation=0.001,\n",
    "                  student_initializer=None,\n",
    "                  teacher_initializer=None,\n",
    "                  minibatch_size=100,\n",
    "                  dev=mx.gpu()):\n",
    "    teacher_exe, teacher_params, teacher_params_grad, _ = \\\n",
    "        get_executor(teacher_sym, dev, teacher_data_inputs, teacher_initializer)\n",
    "    student_exe, student_params, student_params_grad, _ = \\\n",
    "        get_executor(student_sym, dev, student_data_inputs, student_initializer)\n",
    "    teacher_label_key = list(set(teacher_data_inputs.keys()) - set(['data']))[0]\n",
    "    student_label_key = list(set(student_data_inputs.keys()) - set(['data']))[0]\n",
    "    teacher_optimizer = mx.optimizer.create('sgld',\n",
    "                                            learning_rate=teacher_learning_rate,\n",
    "                                            rescale_grad=X.shape[0] / float(minibatch_size),\n",
    "                                            lr_scheduler=teacher_lr_scheduler,\n",
    "                                            wd=teacher_prior_precision)\n",
    "    student_optimizer = mx.optimizer.create(student_optimizing_algorithm,\n",
    "                                            learning_rate=student_learning_rate,\n",
    "                                            rescale_grad=1.0 / float(minibatch_size),\n",
    "                                            lr_scheduler=student_lr_scheduler,\n",
    "                                            wd=student_prior_precision)\n",
    "    teacher_updater = mx.optimizer.get_updater(teacher_optimizer)\n",
    "    student_updater = mx.optimizer.get_updater(student_optimizer)\n",
    "    start = time.time()\n",
    "    for i in range(total_iter_num):\n",
    "        # 1.1 Draw random minibatch\n",
    "        indices = numpy.random.randint(X.shape[0], size=minibatch_size)\n",
    "        X_batch = X[indices]\n",
    "        Y_batch = Y[indices]\n",
    "        \n",
    "        # 1.2 Update teacher\n",
    "        teacher_exe.arg_dict['data'][:] = X_batch\n",
    "        teacher_exe.arg_dict[teacher_label_key][:] = Y_batch\n",
    "        teacher_exe.forward(is_train=True)\n",
    "        teacher_exe.backward()       \n",
    "        for k in teacher_params:\n",
    "            teacher_updater(k, teacher_params_grad[k], teacher_params[k])\n",
    "    \n",
    "        # 2.1 Draw random minibatch and do random perturbation\n",
    "        indices = numpy.random.randint(X.shape[0], size=minibatch_size)\n",
    "        X_student_batch = X[indices] + numpy.random.normal(0, perturb_deviation, X_batch.shape).astype('float32')\n",
    "\n",
    "        # 2.2 Get teacher predictions\n",
    "        teacher_exe.arg_dict['data'][:] = X_student_batch\n",
    "        teacher_exe.forward(is_train=False)\n",
    "        teacher_pred = teacher_exe.outputs[0]\n",
    "        teacher_pred.wait_to_read()\n",
    "\n",
    "        # 2.3 Update student\n",
    "        student_exe.arg_dict['data'][:] = X_student_batch\n",
    "        student_exe.arg_dict[student_label_key][:] = teacher_pred\n",
    "        student_exe.forward(is_train=True)\n",
    "        student_exe.backward()\n",
    "        for k in student_params:\n",
    "            student_updater(k, student_params_grad[k], student_params[k])\n",
    "\n",
    "        if (i + 1) % 2000 == 0:\n",
    "            end = time.time()\n",
    "            print(\"Current Iter Num: %d\" % (i + 1), \"Time Spent: %f\" % (end - start))\n",
    "            test_correct, test_total, test_acc = \\\n",
    "                sample_test_acc(student_exe, X=X_test, Y=Y_test, label_num=10,\n",
    "                                minibatch_size=minibatch_size)\n",
    "            train_correct, train_total, train_acc = \\\n",
    "                sample_test_acc(student_exe, X=X, Y=Y, label_num=10,\n",
    "                                minibatch_size=minibatch_size)\n",
    "            teacher_test_correct, teacher_test_total, teacher_test_acc = \\\n",
    "                sample_test_acc(teacher_exe, X=X_test, Y=Y_test, label_num=10,\n",
    "                                minibatch_size=minibatch_size)\n",
    "            teacher_train_correct, teacher_train_total, teacher_train_acc = \\\n",
    "                sample_test_acc(teacher_exe, X=X, Y=Y, label_num=10,\n",
    "                                minibatch_size=minibatch_size)\n",
    "            print(\"Student: Test %d/%d=%f, Train %d/%d=%f\" % (test_correct, test_total, test_acc,\n",
    "                                                       train_correct, train_total, train_acc))\n",
    "            print(\"Teacher: Test %d/%d=%f, Train %d/%d=%f\" \\\n",
    "                  % (teacher_test_correct, teacher_test_total, teacher_test_acc,\n",
    "                     teacher_train_correct, teacher_train_total, teacher_train_acc))\n",
    "            start = time.time()\n"
   ]
  },
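  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The student update in step 2.3 uses the teacher's current prediction as a soft label, and the teacher's parameters change at every iteration. The toy sketch below suggests why this works. It uses plain NumPy, with hypothetical random distributions standing in for the predictions of sampled teachers on a single input. Stochastic gradient descent on the cross-entropy against a stream of sampled targets drives the student's softmax output toward the average of those targets, which is a Monte Carlo estimate of the posterior predictive distribution. The names `teacher_preds`, `student_logits` and `lr` are assumptions made for illustration; with a constant learning rate the student only fluctuates around the average rather than matching it exactly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from __future__ import division, print_function\n",
    "import numpy\n",
    "\n",
    "\n",
    "def softmax(z):\n",
    "    # Numerically stable softmax of a 1-D vector\n",
    "    e = numpy.exp(z - z.max())\n",
    "    return e / e.sum()\n",
    "\n",
    "\n",
    "rng = numpy.random.RandomState(0)\n",
    "# Hypothetical stand-in for teacher samples: each row is one sampled\n",
    "# teacher's predictive distribution over 3 classes for a single input.\n",
    "teacher_preds = numpy.array([softmax(rng.normal(size=3)) for _ in range(200)])\n",
    "\n",
    "student_logits = numpy.zeros(3)\n",
    "lr = 0.05\n",
    "for epoch in range(200):\n",
    "    for p in teacher_preds[rng.permutation(teacher_preds.shape[0])]:\n",
    "        # Gradient of the cross-entropy w.r.t. the logits is softmax - target\n",
    "        student_logits -= lr * (softmax(student_logits) - p)\n",
    "\n",
    "print('student prediction:     ', softmax(student_logits))\n",
    "print('mean teacher prediction:', teacher_preds.mean(axis=0))\n"
   ]
  },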
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can train a student network using 500 samples from MNIST."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class CrossEntropySoftmax(mx.operator.NumpyOp):\n",
    "    def __init__(self):\n",
    "        super(CrossEntropySoftmax, self).__init__(False)\n",
    "\n",
    "    def list_arguments(self):\n",
    "        return ['data', 'label']\n",
    "\n",
    "    def list_outputs(self):\n",
    "        return ['output']\n",
    "\n",
    "    def infer_shape(self, in_shape):\n",
    "        data_shape = in_shape[0]\n",
    "        label_shape = in_shape[0]\n",
    "        output_shape = in_shape[0]\n",
    "        return [data_shape, label_shape], [output_shape]\n",
    "\n",
    "    def forward(self, in_data, out_data):\n",
    "        x = in_data[0]\n",
    "        y = out_data[0]\n",
    "        y[:] = numpy.exp(x - x.max(axis=1).reshape((x.shape[0], 1))).astype('float32')\n",
    "        y /= y.sum(axis=1).reshape((x.shape[0], 1))\n",
    "\n",
    "    def backward(self, out_grad, in_data, out_data, in_grad):\n",
    "        l = in_data[1]\n",
    "        y = out_data[0]\n",
    "        dx = in_grad[0]\n",
    "        dx[:] = (y - l)\n",
    "\n",
    "        \n",
    "class BiasXavier(mx.initializer.Xavier):\n",
    "    def _init_bias(self, _, arr):\n",
    "        scale = numpy.sqrt(self.magnitude / arr.shape[0])\n",
    "        mx.random.uniform(-scale, scale, out=arr)\n",
    "        \n",
    "        \n",
    "def get_mnist_sym(output_op=None, num_hidden=400):\n",
    "    net = mx.symbol.Variable('data')\n",
    "    net = mx.symbol.FullyConnected(data=net, name='mnist_fc1', num_hidden=num_hidden)\n",
    "    net = mx.symbol.Activation(data=net, name='mnist_relu1', act_type=\"relu\")\n",
    "    net = mx.symbol.FullyConnected(data=net, name='mnist_fc2', num_hidden=num_hidden)\n",
    "    net = mx.symbol.Activation(data=net, name='mnist_relu2', act_type=\"relu\")\n",
    "    net = mx.symbol.FullyConnected(data=net, name='mnist_fc3', num_hidden=10)\n",
    "    if output_op is None:\n",
    "        net = mx.symbol.SoftmaxOutput(data=net, name='softmax')\n",
    "    else:\n",
    "        net = output_op(data=net, name='softmax')\n",
    "    return net\n",
    "\n",
    "def dev():\n",
    "    return mx.gpu()\n",
    "\n",
    "def run_mnist_DistilledSGLD(training_num=50000):\n",
    "    X, Y, X_test, Y_test = load_mnist(training_num)\n",
    "    minibatch_size = 100\n",
    "    if training_num >= 10000:\n",
    "        num_hidden = 800\n",
    "        total_iter_num = 1000000\n",
    "        teacher_learning_rate = 1E-6\n",
    "        student_learning_rate = 0.0001\n",
    "        teacher_prior = 1\n",
    "        student_prior = 0.1\n",
    "        perturb_deviation = 0.1\n",
    "    else:\n",
    "        num_hidden = 400\n",
    "        total_iter_num = 20000\n",
    "        teacher_learning_rate = 4E-5\n",
    "        student_learning_rate = 0.0001\n",
    "        teacher_prior = 1\n",
    "        student_prior = 0.1\n",
    "        perturb_deviation = 0.001\n",
    "    teacher_net = get_mnist_sym(num_hidden=num_hidden)\n",
    "    crossentropy_softmax = CrossEntropySoftmax()\n",
    "    student_net = get_mnist_sym(output_op=crossentropy_softmax, num_hidden=num_hidden)\n",
    "    data_shape = (minibatch_size,) + X.shape[1::]\n",
    "    teacher_data_inputs = {'data': nd.zeros(data_shape, ctx=dev()),\n",
    "                           'softmax_label': nd.zeros((minibatch_size,), ctx=dev())}\n",
    "    student_data_inputs = {'data': nd.zeros(data_shape, ctx=dev()),\n",
    "                           'softmax_label': nd.zeros((minibatch_size, 10), ctx=dev())}\n",
    "    teacher_initializer = BiasXavier(factor_type=\"in\", magnitude=1)\n",
    "    student_initializer = BiasXavier(factor_type=\"in\", magnitude=1)\n",
    "    DistilledSGLD(teacher_sym=teacher_net, student_sym=student_net,\n",
    "                  teacher_data_inputs=teacher_data_inputs,\n",
    "                  student_data_inputs=student_data_inputs,\n",
    "                  X=X, Y=Y, X_test=X_test, Y_test=Y_test, total_iter_num=total_iter_num,\n",
    "                  student_initializer=student_initializer,\n",
    "                  teacher_initializer=teacher_initializer,\n",
    "                  student_optimizing_algorithm=\"adam\",\n",
    "                  teacher_learning_rate=teacher_learning_rate,\n",
    "                  student_learning_rate=student_learning_rate,\n",
    "                  teacher_prior_precision=teacher_prior, student_prior_precision=student_prior,\n",
    "                  perturb_deviation=perturb_deviation, minibatch_size=minibatch_size, dev=dev())\n"
   ]
  },
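  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `backward` method of `CrossEntropySoftmax` writes `y - l` as the gradient with respect to the input. As a sanity check, the sketch below compares that expression against a central finite-difference gradient of the cross-entropy between soft labels and the softmax output. It uses plain NumPy and is independent of MXNet; the helper names `softmax_rows` and `soft_cross_entropy` are hypothetical. It assumes each row of the soft labels sums to one, which holds for teacher predictions produced by a softmax. A very small printed difference indicates that the two gradients agree."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from __future__ import division, print_function\n",
    "import numpy\n",
    "\n",
    "\n",
    "def softmax_rows(x):\n",
    "    # Numerically stable row-wise softmax\n",
    "    e = numpy.exp(x - x.max(axis=1, keepdims=True))\n",
    "    return e / e.sum(axis=1, keepdims=True)\n",
    "\n",
    "\n",
    "def soft_cross_entropy(x, l):\n",
    "    # Summed cross-entropy between soft labels l and softmax_rows(x)\n",
    "    return -numpy.sum(l * numpy.log(softmax_rows(x)))\n",
    "\n",
    "\n",
    "rng = numpy.random.RandomState(0)\n",
    "x = rng.normal(size=(4, 10))                # toy logits\n",
    "l = softmax_rows(rng.normal(size=(4, 10)))  # toy soft labels, rows sum to 1\n",
    "\n",
    "analytic = softmax_rows(x) - l\n",
    "numeric = numpy.zeros_like(x)\n",
    "h = 1e-5\n",
    "for idx in numpy.ndindex(*x.shape):\n",
    "    step = numpy.zeros_like(x)\n",
    "    step[idx] = h\n",
    "    numeric[idx] = (soft_cross_entropy(x + step, l)\n",
    "                    - soft_cross_entropy(x - step, l)) / (2 * h)\n",
    "\n",
    "print('max abs difference: %g' % numpy.abs(analytic - numeric).max())\n"
   ]
  },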
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Current Iter Num: 2000 Time Spent: 71.816000\n",
      "Student: Test 8456/10000=0.845600, Train 500/500=1.000000\n",
      "Teacher: Test 6945/10000=0.694500, Train 495/500=0.990000\n",
      "Current Iter Num: 4000 Time Spent: 68.814000\n",
      "Student: Test 8557/10000=0.855700, Train 500/500=1.000000\n",
      "Teacher: Test 7187/10000=0.718700, Train 497/500=0.994000\n",
      "Current Iter Num: 6000 Time Spent: 68.610000\n",
      "Student: Test 8559/10000=0.855900, Train 500/500=1.000000\n",
      "Teacher: Test 7248/10000=0.724800, Train 499/500=0.998000\n",
      "Current Iter Num: 8000 Time Spent: 68.404000\n",
      "Student: Test 8541/10000=0.854100, Train 500/500=1.000000\n",
      "Teacher: Test 7320/10000=0.732000, Train 500/500=1.000000\n",
      "Current Iter Num: 10000 Time Spent: 68.503000\n",
      "Student: Test 8488/10000=0.848800, Train 500/500=1.000000\n",
      "Teacher: Test 7264/10000=0.726400, Train 499/500=0.998000\n",
      "Current Iter Num: 12000 Time Spent: 68.417000\n",
      "Student: Test 8585/10000=0.858500, Train 500/500=1.000000\n",
      "Teacher: Test 7593/10000=0.759300, Train 500/500=1.000000\n",
      "Current Iter Num: 14000 Time Spent: 68.342000\n",
      "Student: Test 8602/10000=0.860200, Train 500/500=1.000000\n",
      "Teacher: Test 7563/10000=0.756300, Train 500/500=1.000000\n",
      "Current Iter Num: 16000 Time Spent: 68.484000\n",
      "Student: Test 8559/10000=0.855900, Train 500/500=1.000000\n",
      "Teacher: Test 7345/10000=0.734500, Train 497/500=0.994000\n",
      "Current Iter Num: 18000 Time Spent: 69.646000\n",
      "Student: Test 8523/10000=0.852300, Train 500/500=1.000000\n",
      "Teacher: Test 7618/10000=0.761800, Train 500/500=1.000000\n",
      "Current Iter Num: 20000 Time Spent: 68.815000\n",
      "Student: Test 8637/10000=0.863700, Train 500/500=1.000000\n",
      "Teacher: Test 7600/10000=0.760000, Train 500/500=1.000000\n"
     ]
    }
   ],
   "source": [
    "numpy.random.seed(100)\n",
    "mx.random.seed(100)\n",
    "run_mnist_DistilledSGLD(500)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<!--bibtex\n",
    "\n",
    "@inproceedings{korattikara2015bayesian,\n",
    "  title={Bayesian dark knowledge},\n",
    "  author={Korattikara, Anoop and Rathod, Vivek and Murphy, Kevin and Welling, Max},\n",
    "  booktitle={NIPS 2015},\n",
    "  year={2015},\n",
    "  url=\"http://papers.nips.cc/paper/5965-bayesian-dark-knowledge.pdf\"\n",
    "}\n",
    "\n",
    "@inproceedings{hernandez2015probabilistic,\n",
    "  title={Probabilistic backpropagation for scalable learning of bayesian neural networks},\n",
    "  author={Hern{\\'a}ndez-Lobato, Jos{\\'e} Miguel and Adams, Ryan P},\n",
    "  booktitle={ICML 2015},\n",
    "  year={2015},\n",
    "  url=\"http://jmlr.org/proceedings/papers/v37/hernandez-lobatoc15.pdf\"\n",
    "}\n",
    "\n",
    "@misc{bdl,\n",
    "  Author = {Yeung, Dit-Yan and Wang, Hao and Wang, Naiyan and Shi, Xingjian},\n",
    "  Institution = {Hong Kong University of Science and Technology},\n",
    "  Howpublished = {ACML 2015 Talk},\n",
    "  Year = {2015},\n",
    "  Title = {Bayesian deep learning for integrated intelligence: bridging the gap between perception and inference},\n",
    "  url=\"http://www.wanghao.in/mis/BDL_ACML.pdf\"\n",
    "}\n",
    "\n",
    "-->"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# References\n",
    "\n",
    "<a name=\"cite-korattikara2015bayesian\"/><sup>[^](#ref-1) </sup>Anoop Korattikara, Vivek Rathod, Kevin Murphy and Max Welling. 2015. _Bayesian dark knowledge_. NIPS 2015. [URL](http://papers.nips.cc/paper/5965-bayesian-dark-knowledge.pdf)\n",
    "\n",
    "<a name=\"cite-bdl\"/><sup>[^](#ref-2) </sup>Dit-Yan Yeung, Hao Wang, Naiyan Wang and Xingjian Shi. 2015. _Bayesian deep learning for integrated intelligence: bridging the gap between perception and inference_. ACML 2015 Talk. [URL](http://www.wanghao.in/mis/BDL_ACML.pdf)\n",
    "\n",
    "<a name=\"cite-hernandez2015probabilistic\"/><sup>[^](#ref-3) </sup>Jos&eacute; Miguel Hern&aacute;ndez-Lobato and Ryan P. Adams. 2015. _Probabilistic backpropagation for scalable learning of Bayesian neural networks_. ICML 2015. [URL](http://jmlr.org/proceedings/papers/v37/hernandez-lobatoc15.pdf)\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
